import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import stats
from pandas import  DataFrame
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score, GroupKFold
from sklearn.metrics import mean_squared_error, r2_score

#----------------------------------------------------------------------------------------------------------------------------------
#                             Shared plot tick settings and substrate lists
#----------------------------------------------------------------------------------------------------------------------------------


# Several training sets share plot axis settings, so each tick set is defined
# once here and reused.  Major ticks sit every 2 units; the minor ticks fill in
# each gap between neighbouring major ticks at 0.5-unit steps.

# Oxo data
ticksMajorOxo = [1, 3, 5, 7, 9, 11, 13]
ticksMinorOxo = [
	1.5, 2, 2.5,
	3.5, 4, 4.5,
	5.5, 6, 6.5,
	7.5, 8, 8.5,
	9.5, 10, 10.5,
	11.5, 12, 12.5,
]

# Oxo data, axis shifted two units lower
ticksMajorOxoLow = [-1, 1, 3, 5, 7, 9, 11]
ticksMinorOxoLow = [
	-0.5, 0, 0.5,
	1.5, 2, 2.5,
	3.5, 4, 4.5,
	5.5, 6, 6.5,
	7.5, 8, 8.5,
	9.5, 10, 10.5,
]

# CoO data
ticksMajorCoO = [-3, -1, 1, 3, 5, 7]
ticksMinorCoO = [
	-2.5, -2, -1.5,
	-0.5, 0, 0.5,
	1.5, 2, 2.5,
	3.5, 4, 4.5,
	5.5, 6, 6.5,
]

# RuO data
ticksMajorRuO = [0, 2, 4, 6, 8, 10, 12]
ticksMinorRuO = [
	0.5, 1, 1.5,
	2.5, 3, 3.5,
	4.5, 5, 5.5,
	6.5, 7, 7.5,
	8.5, 9, 9.5,
	10.5, 11, 11.5,
]


# Index-aligned substrate metadata: full name, the abbreviation used as a
# column-name prefix (e.g. "DHA PCET Barrier"), and a matplotlib-style
# format string (presumably for per-substrate plots -- not used in this file).
substrates = ["Dihydroanthracene", "Cyclohexadiene", "Xanthene", "Fluorene"]
subs = ["DHA", "CHD", "Xth", "Fl"]
subMarks = ['ko', 'gX', 'bD', 'r^']

#----------------------------------------------------------------------------------------------------------------------------------
#                             "Compressing" Data for substrates of interest
#----------------------------------------------------------------------------------------------------------------------------------

def compressSubstrates(uncompressed, xVals):
	"""Stack the rows of every substrate in ``subs`` into one table.

	For each abbreviation in the module-level ``subs`` list, the rows of
	``uncompressed`` whose ``"<abbr> PCET Barrier"`` column is not null are
	selected. From those rows the columns named in ``xVals`` are taken, except
	that any name beginning with "Sub" has that prefix replaced by the
	abbreviation (so "Sub Foo" reads "DHA Foo" for DHA rows). The per-substrate
	blocks are stacked vertically in ``subs`` order.

	Parameters
	----------
	uncompressed : pandas.DataFrame
		Wide table holding per-substrate columns.
	xVals : list of str
		Output column names; "Sub..." entries act as substrate placeholders.

	Returns
	-------
	(pandas.DataFrame, numpy.ndarray)
		The stacked data (columns named exactly ``xVals``) and the cumulative
		row offsets, of length ``len(subs) + 1``. Rows of substrate ``i`` are
		``starts[i]:starts[i + 1]``.
	"""
	# Seed with an empty float block so the stacked dtype matches the original.
	pieces = [np.empty((0, len(xVals)))]
	starts = [0]

	for abbrev in subs:
		sourceColumns = [abbrev + name[3:] if name[:3] == "Sub" else name for name in xVals]
		rowsWithData = uncompressed[uncompressed[abbrev + " PCET Barrier"].notnull()]
		pieces.append(rowsWithData[sourceColumns].values)
		starts.append(starts[-1] + rowsWithData.shape[0])

	stacked = np.concatenate(pieces, axis=0)
	return DataFrame(data = stacked, columns = xVals), np.array(starts)
	

#----------------------------------------------------------------------------------------------------------------------------------
#                             Statistics Functions
#----------------------------------------------------------------------------------------------------------------------------------

def ErrorsCV(model, X, y, threshold = 0.1, nSplits = 5, minRepeats = 11, maxRepeats = None):
	"""Repeated shuffled k-fold cross-validation MSE of ``model``.

	Each repeat runs ``nSplits``-fold CV with ``KFold(shuffle=True)`` seeded by
	the repeat index (0, 1, 2, ...), so results are reproducible. It records
	the mean MSE over the folds. Repeats continue until at least ``minRepeats``
	values exist and their standard error (population std / sqrt(count)) is
	below ``threshold``.

	Parameters
	----------
	model : estimator
		sklearn-compatible estimator passed to ``cross_val_score``.
	X, y : array-like
		Features and targets.
	threshold : float
		Convergence target for the standard error of the mean MSE.
	nSplits : int
		Number of folds per repeat (default 5, as before).
	minRepeats : int
		Minimum number of repeats before convergence is checked (default 11,
		matching the previous hard-coded ``size > 10``).
	maxRepeats : int or None
		Optional cap on the number of repeats. If it is reached first, a
		``RuntimeWarning`` is issued and the results so far are returned.
		``None`` (default) means no cap.

	Returns
	-------
	numpy.ndarray
		One mean-fold MSE per repeat, in seed order.

	Raises
	------
	ValueError
		If ``threshold`` is not positive and ``maxRepeats`` is None. The
		standard error is never negative, so the loop could never end.
	"""
	if maxRepeats is None and not threshold > 0:
		raise ValueError("threshold must be positive when maxRepeats is None; the loop would never terminate")

	cvs = np.zeros(0)
	seed = 0
	while True:
		folds = KFold(nSplits, shuffle=True, random_state=seed)
		foldScores = cross_val_score(model, X, y, cv=folds, scoring='neg_mean_squared_error')
		# The scorer returns negated MSE; flip the sign back
		cvs = np.append(cvs, -foldScores.mean())
		seed += 1

		if cvs.size >= minRepeats and cvs.std()/np.sqrt(cvs.size) < threshold:
			return cvs
		if maxRepeats is not None and cvs.size >= maxRepeats:
			import warnings
			warnings.warn("ErrorsCV stopped after %d repeats without reaching threshold %g" % (cvs.size, threshold), RuntimeWarning)
			return cvs
	
def LeaveOneGroupOut(model, X, y, grouping):
	"""Leave-one-group-out predictions of ``y``.

	For each distinct label in ``grouping`` (processed in order of first
	appearance), a clone of ``model`` is fitted on every sample outside that
	group. The fitted model then predicts the samples inside it.

	Parameters
	----------
	model : estimator
		sklearn-compatible estimator; cloned, so the caller's instance is
		left untouched.
	X : numpy.ndarray
		2-D feature matrix, one row per sample.
	y : numpy.ndarray
		1-D target vector aligned with the rows of ``X``.
	grouping : array-like
		Group label for each sample, same length as ``y`` (a list or Series
		is accepted and converted to an array).

	Returns
	-------
	numpy.ndarray
		Float array of length ``y.size``. Entry k is the prediction for sample
		k from a model that never saw k's group.

	Raises
	------
	ValueError
		If a single group contains every sample, leaving no training data.
	"""
	model = clone(model)
	grouping = np.asarray(grouping)
	predicted = np.zeros(y.size)
	# Boolean mask of samples already predicted: O(1) lookup, unlike a list
	done = np.zeros(y.size, dtype=bool)

	for i in range(y.size):
		if done[i]:
			continue

		inGroup = grouping == grouping[i]
		if inGroup.all():
			raise ValueError("group %r contains every sample; nothing left to train on" % (grouping[i],))

		model.fit(X[~inGroup, :], y[~inGroup])
		predicted[inGroup] = model.predict(X[inGroup, :])
		done |= inGroup

	return predicted
	
def LeaveOneOut(model, X, y):
	"""Leave-one-out predictions of ``y``.

	For every sample index ``k``, a clone of ``model`` is fitted on all other
	rows of ``X``/``y`` and asked to predict row ``k`` alone.

	Parameters
	----------
	model : estimator
		sklearn-compatible estimator; cloned before fitting.
	X : numpy.ndarray
		2-D feature matrix, one row per sample.
	y : numpy.ndarray
		1-D target vector aligned with the rows of ``X``.

	Returns
	-------
	numpy.ndarray
		The held-out prediction for each sample, in sample order.
	"""
	estimator = clone(model)
	nSamples = y.size
	rowIds = np.arange(nSamples)

	def _predictHeldOut(k):
		# Train on every row except k, then predict row k as a 1-row matrix
		keep = rowIds != k
		estimator.fit(X[keep], y[keep])
		return estimator.predict(X[k:k + 1])[0]

	return np.array([_predictHeldOut(k) for k in range(nSamples)])
	
def FTest(model, nullX, allX, y, alpha = 0.05, extraConstraints = 0):
	"""F-test of a restricted (null) model against an unrestricted one.

	Based on the formula in the Econometrics book (ddf p. 252):

		F = ((SSE_null - SSE_alt) / J) / (SSE_alt / (N - K))

	Here K = allX.shape[1] + 1 is the number of unrestricted parameters
	(including an intercept) and J is the number of restrictions. N is
	``y.size - extraConstraints``.

	Parameters
	----------
	model : estimator
		sklearn-compatible regressor with an intercept; cloned for each fit.
	nullX : numpy.ndarray or None
		Features of the restricted model. ``None`` means an intercept-only
		model, i.e. predicting ``y.mean()``.
	allX : numpy.ndarray
		Features of the unrestricted model.
	y : numpy.ndarray
		1-D targets.
	alpha : float
		Significance level.
	extraConstraints : int
		Extra degrees of freedom to subtract from N.

	Returns
	-------
	(bool, float)
		Whether F exceeds the (1 - alpha) critical value (i.e. there is
		1 - alpha confidence that ``nullX`` is insufficient), and the p-value.

	Raises
	------
	ValueError
		If the numerator or denominator degrees of freedom are not positive.
	"""
	# Restricted (null) model
	if nullX is None:
		errs = y - y.mean()
		nullSSE = errs @ errs
		nullParams = 1
	else:
		nullModel = clone(model)
		nullModel.fit(nullX, y)
		errs = y - nullModel.predict(nullX)
		nullSSE = errs @ errs
		nullParams = nullX.shape[1] + 1

	# Unrestricted (alternate) model
	altModel = clone(model)
	altModel.fit(allX, y)
	errs = y - altModel.predict(allX)
	altSSE = errs @ errs

	N = y.size - extraConstraints
	K = allX.shape[1] + 1
	J = K - nullParams
	dfDenom = N - K
	if J <= 0:
		raise ValueError("null model must have fewer parameters than the full model (J = %d)" % J)
	if dfDenom <= 0:
		raise ValueError("not enough observations for the F-test (N - K = %d)" % dfDenom)

	F = ( (nullSSE - altSSE)/J ) / ( altSSE/dfDenom )
	crit = stats.f.ppf(1-alpha, J, dfDenom)
	# sf is exact in the upper tail, whereas 1 - cdf rounds small p-values to 0
	pValue = stats.f.sf(F, J, dfDenom)

	return F > crit, pValue

def tInterval(model, X, y, alpha = 0.05):
	"""Parameter standard errors and t-scaled half-widths for a linear fit.

	Fits a clone of ``model`` to (X, y) and applies the classical OLS result
	Var(beta) = s^2 (D^T D)^-1. D is ``X`` with a leading column of ones, and
	s^2 = SSE / (n - p), where p is the number of columns of D. Mixes scipy,
	sklearn and formulas from the Freedman and Econometrics books. It reads
	``intercept_`` and ``coef_``, so it is probably only accurate when
	``model`` is a LinearRegression with an intercept.

	Returns
	-------
	stds : numpy.ndarray
		Standard error of each parameter, intercept first.
	halfWidths : numpy.ndarray
		``stds`` scaled by the two-sided (1 - alpha) t critical value with
		n - p degrees of freedom.
	params : numpy.ndarray
		Best-estimate parameters, intercept first followed by ``coef_``.
	"""
	fitted = clone(model)
	fitted.fit(X, y)
	residuals = y - fitted.predict(X)
	params = np.append(fitted.intercept_, fitted.coef_)

	design = np.hstack((np.ones((X.shape[0], 1)), X))
	nObs, nParams = design.shape
	dof = nObs - nParams

	residualVariance = residuals.T @ residuals / dof
	variances = np.diag(residualVariance * np.linalg.inv(design.T @ design))
	tCrit = stats.t.interval(1-alpha, dof)[1]
	stds = np.sqrt(variances)

	return stds, tCrit*stds, params
		
def fitAndEvaluate(xVals, FJ, train, yVal = "DHA Barrier Effec", grouping = None):
	"""Fit a LinearRegression of ``yVal`` on ``xVals`` and print diagnostics.

	Parameters
	----------
	xVals : list of str
		Feature column names in ``train``. The last ``FJ`` are the ones the
		F-test examines.
	FJ : int
		Number of trailing features to F-test. If ``FJ >= len(xVals)`` the
		full model is tested against an intercept-only model. If ``FJ <= 0``
		no F-test is run (``X[:, :-0]`` would otherwise select zero columns).
	train : pandas.DataFrame
		Data containing ``xVals``, ``yVal`` and optionally ``grouping``.
	yVal : str
		Target column name.
	grouping : str or None
		Optional column of group labels. When given, leave-one-group-out
		predictions are used, and the F-test and repeated k-fold CV are skipped.

	Returns
	-------
	tuple
		``(model, significant, p, tErr, cvErrs, looPredictions)``:
		- ``model``: fitted on all of ``train``.
		- ``significant`` / ``p``: FTest's (reject-at-alpha bool, p-value), or ``None``.
		- ``tErr``: t-scaled standard errors, intercept first.
		- ``cvErrs``: ErrorsCV's array, or ``None``.
		- ``looPredictions``: the held-out predictions.
	"""
	model = LinearRegression()
	X = train[xVals].values
	y = train[yVal].values

	if grouping is None:
		nNullCols = len(xVals) - FJ
		if FJ <= 0:
			significant, p = None, None
		elif nNullCols > 0:
			significant, p = FTest(model, X[:, :nNullCols], X, y)
		else:
			significant, p = FTest(model, None, X, y)
		stds, tErr, _ = tInterval(model, X, y)
		cvErrs = ErrorsCV(model, X, y)
		looPredictions = LeaveOneOut(model, X, y)
	else:
		significant, p = None, None
		stds, tErr, _ = tInterval(model, X, y)
		cvErrs = None
		looPredictions = LeaveOneGroupOut(model, X, y, train[grouping].values)

	model.fit(X, y)

	print("\n", xVals, "Metrics:\n")
	print("Score on Training Data:\t\t\t" + str(model.score(X, y)))
	print("MSE of Training Data:\t\t\t" + str(mean_squared_error(y, model.predict(X))))
	# Held-out metrics are reported in both modes (LOO or leave-one-group-out)
	print("Score of LOO Cross Validation:\t\t" + str(r2_score(y, looPredictions)))
	print("MSE of LOO Cross Validation:\t\t" + str(mean_squared_error(y, looPredictions)))
	if grouping is None:
		print("MSE of 5-Fold Cross Validation:\t\t" + str(cvErrs.mean()) + "("+str(cvErrs.std()/np.sqrt(cvErrs.size))+")")
		print("F-Test p-value of final "+str(FJ)+" variables:\t" + str(p))
	print()
	print("Correlation Matrix of x-values:")
	print(train[xVals].corr(), "\n")
	print(xVals, "Training Average: \n" + str(X.mean(axis=0)))
	print("\n", xVals, "Training Deviation: \n" + str(X.std(axis=0)))
	print("\n", xVals, "Coefficients:\n" + str(model.coef_))
	# Index 0 of stds/tErr is the intercept; report only the feature terms
	print("\n", xVals, "Standard Error:\n", stds[1:])
	print("\n", xVals, "t-Test \"Error\":\n", tErr[1:])
	print("\n", xVals, "Weighted Coefficients: \n" + str(X.std(axis=0)*model.coef_))
	print("\n", xVals, "Intercept: \n", model.intercept_)
	print("\n\n")

	return model, significant, p, tErr, cvErrs, looPredictions

#----------------------------------------------------------------------------------------------------------------------------------
#                             Plotting
#----------------------------------------------------------------------------------------------------------------------------------

def plotModel(xVals, model, title, train, xTicksMajor=None, xTicksMinor=None, yTicksMajor=None, yTicksMinor=None, limits=(2, 12), test=None, yVal="DHA Barrier Effec", label="Training Set", marker='ko', figAxs = None, figName = None):
	"""Parity plot of experimental vs. predicted ``yVal``.

	Draws a grey y = x reference line over ``limits``, then plots the
	training points (and optionally ``test`` as red squares) with the observed
	value on x and ``model.predict`` on y.

	Parameters
	----------
	xVals : list of str
		Feature columns passed to ``model.predict``.
	model : estimator
		Fitted model with a ``predict`` method.
	title : str
		Window title (the axes title is intentionally not set).
	train : pandas.DataFrame
		Training data containing ``xVals`` and ``yVal``.
	xTicksMajor, yTicksMajor : sequence or None
		Major tick positions. When given, the axis limits are set to the
		first and last tick.
	xTicksMinor, yTicksMinor : sequence or None
		Minor tick positions.
	limits : sequence of two numbers
		End points of the y = x reference line.
	test : pandas.DataFrame or None
		Optional test data to overlay.
	yVal : str
		Target column name.
	label, marker : str
		Legend label and matplotlib format string for the training points.
	figAxs : (Figure, Axes) or None
		Existing figure/axes to draw into; a new pair is created if None.
	figName : str or None
		If given, save to ``figName + ".svg"`` instead of showing the figure.

	Returns
	-------
	(Figure, Axes)
	"""
	if figAxs is None:
		fig, axs = plt.subplots(constrained_layout=True)
	else:
		fig, axs = figAxs[0], figAxs[1]

	axs.plot(limits, limits, '-', color='grey')
	axs.plot(train[yVal], model.predict(train[xVals]), marker, label=label)

	if isinstance(test, DataFrame):
		axs.plot(test[yVal], model.predict(test[xVals]), 'rs', label='Test Data')

	#axs.set_title(title)
	axs.set_xlabel('Experimental Effective Barrier (kcal/mol)')
	axs.set_ylabel('Predicted Effective Barrier (kcal/mol)')
	# FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the
	# manager owns the window title (and may be None for managerless canvases).
	manager = getattr(fig.canvas, "manager", None)
	if manager is not None:
		manager.set_window_title(title)
	axs.legend(loc='upper left')

	# "is not None" (not "!= None") so numpy arrays of ticks work too
	if xTicksMajor is not None:
		axs.set_xticks(xTicksMajor)
		axs.set_xlim(xTicksMajor[0], xTicksMajor[-1])
	if xTicksMinor is not None:
		axs.set_xticks(xTicksMinor, minor=True)
	if yTicksMajor is not None:
		axs.set_yticks(yTicksMajor)
		axs.set_ylim(yTicksMajor[0], yTicksMajor[-1])
	if yTicksMinor is not None:
		axs.set_yticks(yTicksMinor, minor=True)

	if figName is None:
		fig.show()
	else:
		fig.savefig(figName+".svg", dpi=600, format='svg')

	return fig, axs
